#!/usr/bin/env python
import numpy as np
import pandas as pd
import argparse
import yaml
import json
import sys

# Numeric midpoint of each projected-size (R_G) bin label; keys use an EN DASH.
RG_MIDS = {"1.5–3.0": 2.25, "3.0–5.0": 4.0, "5.0–8.0": 6.5, "8.0–12.0": 10.0}

# Dash-like characters that upstream CSVs may use in R_G_bin labels instead of
# the EN DASH (U+2013) used by the RG_MIDS keys. Applied only for the lookup.
_RG_DASH_TABLE = str.maketrans({
    "-": "\u2013",       # ASCII hyphen-minus
    "\u2010": "\u2013",  # hyphen
    "\u2011": "\u2013",  # non-breaking hyphen
    "\u2012": "\u2013",  # figure dash
    "\u2014": "\u2013",  # em dash
    "\u2212": "\u2013",  # minus sign
})


def load_plateau(path):
    """
    Load the plateau table and attach the derived columns used downstream.

    path: anything accepted by ``pandas.read_csv`` (file path or buffer).

    * If a ``claimable`` column exists, only rows where it equals True are kept.
    * ``RG_mid``: numeric midpoint of the ``R_G_bin`` label via RG_MIDS.
      Labels are stripped and dash variants (e.g. an ASCII hyphen) are
      normalized to an en dash for the lookup only; the ``R_G_bin`` column
      itself is left untouched. Unknown labels map to NaN.
    * ``sigma``: approximate 1-sigma as half the width of the
      [A_theta_CI_low, A_theta_CI_high] interval (treated as a 16-84% CI),
      or 1e-3 when those columns are absent. Infinite/NaN widths are replaced
      by the median positive width and non-positive widths by the smallest
      positive width; if no width is positive, every sigma is set to 1e-3.

    Returns a new DataFrame.
    """
    P = pd.read_csv(path)
    if "claimable" in P.columns:
        # `== True` (not truthiness) so NaN/other values in an object column drop out
        P = P[P["claimable"] == True].copy()  # noqa: E712

    # Normalize label spelling before mapping so "3.0-5.0" finds "3.0–5.0".
    rg_labels = P["R_G_bin"].astype(str).str.strip().str.translate(_RG_DASH_TABLE)
    P["RG_mid"] = rg_labels.map(RG_MIDS)

    # Approximate 1σ from 16–84% CI
    if {"A_theta_CI_low", "A_theta_CI_high"}.issubset(P.columns):
        sig = (P["A_theta_CI_high"] - P["A_theta_CI_low"]) / 2.0
    else:
        sig = pd.Series(1e-3, index=P.index)

    sig = sig.replace([np.inf, -np.inf], np.nan)
    if (sig > 0).any():
        sig = sig.fillna(sig[sig > 0].median())
        # zero/negative widths would give infinite or invalid weights downstream
        sig[sig <= 0] = sig[sig > 0].min()
    else:
        sig = pd.Series(1e-3, index=P.index)

    P["sigma"] = sig
    return P


def wls_fit(X, y, w):
    """
    Weighted least squares with a simple AIC:
        AIC = n * ln(RSS / n) + 2k
    where RSS is the weighted residual sum of squares and k counts the
    intercept plus the columns of X.

    X: (n, p) array, or 1-D of length n, WITHOUT an intercept column
       (one is prepended here).
    y: length-n response.
    w: length-n weights; clipped to >= 1e-12 and the SAME clipped weights are
       used both for the fit and for RSS so the two stay consistent.

    Returns (beta, rss, aic) with beta[0] the intercept. If the linear solve
    fails, returns (all-NaN beta, inf, inf). RSS/n is floored at 1e-12 before
    the log, so an exact (e.g. saturated) fit yields a finite but
    floor-dominated AIC.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    w = np.clip(np.asarray(w, dtype=float), 1e-12, None)

    X1 = np.column_stack([np.ones(len(X)), X])
    Wsqrt = np.sqrt(w)
    Xw = X1 * Wsqrt[:, None]
    yw = y * Wsqrt
    try:
        beta = np.linalg.lstsq(Xw, yw, rcond=None)[0]
    except (np.linalg.LinAlgError, ValueError):
        # best effort: signal "no fit" with infinite RSS/AIC
        return np.full(X1.shape[1], np.nan), np.inf, np.inf

    resid = y - X1 @ beta
    rss = float(np.sum(w * resid ** 2))
    n = len(y)
    k = X1.shape[1]
    # guard log(0) for exact fits
    aic = n * np.log(max(rss / max(n, 1), 1e-12)) + 2 * k
    return beta, rss, aic


def model_compare(P, Xtab, key_frac="frac_x_gt", min_lenses=0, collin_max=1.0, min_points=4):
    """
    Compare size-only vs size+activation WLS models within each mass bin.

    P and Xtab are inner-joined on (Mstar_bin, R_G_bin). For every Mstar_bin
    group, with y = A_theta and weights 1 / max(sigma, 1e-6)**2:
        size-only:        A_theta ~ 1 + RG_mid
        size+activation:  A_theta ~ 1 + RG_mid + <key_frac>
    dAIC = AIC(size-only) - AIC(size+activation); positive values favor the
    activation model. dAIC is summed over all bins that pass the guards.

    Parameters
    ----------
    key_frac : name of the activation-proxy column in the merged table.
    min_lenses : if > 0 and an ``n_lenses`` column is present, rows with fewer
        lenses are dropped before fitting.
    collin_max : a bin is skipped when |corr(RG_mid, key_frac)| exceeds this;
        the default 1.0 effectively disables the guard.
    min_points : minimum number of rows a bin needs. The size+activation model
        has 3 parameters, so with 3 rows it interpolates the data exactly,
        RSS/n hits the 1e-12 floor in ``wls_fit`` and dAIC reflects that
        floor rather than the data. The default 4 keeps at least one
        residual degree of freedom; pass 3 to reproduce the old behavior.

    Returns
    -------
    dict with
        n               number of mass bins used,
        dAICc           summed dAIC (plain AIC; key name kept for callers),
        slope_vs_frac   mean per-bin unweighted slope of A_theta on key_frac,
        used_mass_bins  list of the Mstar_bin values used.
    """
    M = pd.merge(P, Xtab, on=["Mstar_bin", "R_G_bin"], how="inner")
    used_bins = []
    dAIC_sum = 0.0
    slopes = []

    for mb, g in M.groupby("Mstar_bin"):
        # optional lens-count floor
        if "n_lenses" in g.columns and min_lenses > 0:
            g = g[g["n_lenses"] >= min_lenses]

        # enough rows for a non-saturated full model, and the proxy must exist
        if len(g) < min_points or key_frac not in g.columns:
            continue

        y = g["A_theta"].to_numpy()
        w = 1.0 / np.maximum(g["sigma"].to_numpy(), 1e-6) ** 2

        x_size = g["RG_mid"].to_numpy()
        x_act = g[key_frac].to_numpy()

        # constant predictors make the fit degenerate; optional collinearity guard
        if np.std(x_size) < 1e-12 or np.std(x_act) < 1e-12:
            continue
        corr = float(np.corrcoef(x_size, x_act)[0, 1])
        if np.isfinite(corr) and abs(corr) > collin_max:
            continue

        _, _, aic_size = wls_fit(x_size.reshape(-1, 1), y, w)
        _, _, aic_act = wls_fit(np.column_stack([x_size, x_act]), y, w)
        if not (np.isfinite(aic_size) and np.isfinite(aic_act)):
            continue

        dAIC_sum += aic_size - aic_act

        # simple (unweighted, population-moment) slope of A_theta vs activation proxy
        slope = np.cov(x_act, y, bias=True)[0, 1] / (np.var(x_act) + 1e-12)
        slopes.append(float(slope))
        used_bins.append(mb)

    if not used_bins:
        return dict(n=0, dAICc=np.nan, slope_vs_frac=np.nan, used_mass_bins=[])

    return dict(
        n=len(used_bins),
        dAICc=float(dAIC_sum),  # plain AIC difference despite the key name
        slope_vs_frac=float(np.nanmean(slopes)),
        used_mass_bins=used_bins,
    )


def mc_out_gt_mid(P, draws=100000, seed=42):
    """
    Monte-Carlo test of "outer bin above the middle bins", per mass bin.

    For every Mstar_bin group of P, draws ``draws`` Gaussian samples of
    A_theta (mean A_theta, std sigma) at RG_mid = 4.0, 6.5 and 10.0 and forms
        dom = A(10.0) - 0.5 * (A(4.0) + A(6.5)).
    If several rows share an RG_mid, the last one wins. Groups missing any of
    the three midpoints get NaN results and consume no random numbers.

    Returns {Mstar_bin: {"P_out_gt_mid": P(dom > 0), "Delta_out_mid": mean(dom)}}.
    """
    rng = np.random.default_rng(seed)
    targets = (4.0, 6.5, 10.0)
    results = {}
    for mass_bin, grp in P.groupby("Mstar_bin"):
        by_rg = {
            rg: (amp, err)
            for rg, amp, err in zip(grp["RG_mid"], grp["A_theta"], grp["sigma"])
        }
        if any(rg not in by_rg for rg in targets):
            results[mass_bin] = dict(P_out_gt_mid=np.nan, Delta_out_mid=np.nan)
            continue
        # draw order 4.0 -> 6.5 -> 10.0 fixes the RNG stream per bin
        inner, middle, outer = (rng.normal(*by_rg[rg], draws) for rg in targets)
        contrast = outer - 0.5 * (inner + middle)
        results[mass_bin] = dict(
            P_out_gt_mid=float(np.mean(contrast > 0)),
            Delta_out_mid=float(np.mean(contrast)),
        )
    return results


def main():
    """
    Command-line entry point.

    Loads the plateau table (--plateau) and the activation table (--stack_x),
    reads settings from a YAML config (--config), runs ``model_compare`` for
    every (R_MW_kpc, eta) group of the activation table and writes that scan
    to --out_scan (CSV). The scan row with the largest summed dAIC, together
    with the per-mass-bin Monte-Carlo contrast from ``mc_out_gt_mid``, is
    written to --out_json.

    Optional config keys: activation_key, mc_draws, random_seed, min_lenses,
    collinearity_max.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--plateau", required=True)
    ap.add_argument("--stack_x", required=True)
    ap.add_argument("--config", required=True)
    ap.add_argument("--out_json", required=True)
    ap.add_argument("--out_scan", required=True)
    args = ap.parse_args()

    # Close the config promptly; an empty YAML document loads as None.
    with open(args.config, encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh) or {}
    key_frac = str(cfg.get("activation_key", "frac_x_gt"))
    draws = int(cfg.get("mc_draws", 100000))
    seed = int(cfg.get("random_seed", 42))
    min_lenses = int(cfg.get("min_lenses", 0))
    collin_max = float(cfg.get("collinearity_max", 1.0))

    P = load_plateau(args.plateau)
    Xall = pd.read_csv(args.stack_x)

    rows = []
    for (R_MW, eta), Xt in Xall.groupby(["R_MW_kpc", "eta"]):
        out = model_compare(
            P,
            Xt,
            key_frac=key_frac,
            min_lenses=min_lenses,
            collin_max=collin_max,
        )
        rows.append(dict(R_MW_kpc=float(R_MW), eta=float(eta), **out))

    # Explicit columns keep sort_values (and the dAICc lookup) working when
    # the activation table yields no groups at all.
    scan_columns = ["R_MW_kpc", "eta", "n", "dAICc", "slope_vs_frac", "used_mass_bins"]
    SCAN = pd.DataFrame(rows, columns=scan_columns).sort_values(["eta", "R_MW_kpc"])
    SCAN.to_csv(args.out_scan, index=False)

    best = None
    if SCAN["dAICc"].notna().any():
        best = SCAN.loc[SCAN["dAICc"].idxmax()]

    mc = mc_out_gt_mid(P, draws=draws, seed=seed)

    if best is None:
        best_out = dict(
            R_MW_kpc=None,
            eta=None,
            dAICc=None,
            slope_vs_frac=None,
            used_mass_bins=[],
        )
    else:
        best_out = dict(
            R_MW_kpc=float(best["R_MW_kpc"]),
            eta=float(best["eta"]),
            dAICc=float(best["dAICc"]),
            slope_vs_frac=float(best["slope_vs_frac"]),
            used_mass_bins=list(best["used_mass_bins"]),
        )

    OUT = dict(
        plateau_file=args.plateau,
        stack_x_file=args.stack_x,
        settings=dict(
            activation_key=key_frac,
            min_lenses=min_lenses,
            collinearity_max=collin_max,
        ),
        best=best_out,
        per_mass_out_gt_mid=mc,
    )

    with open(args.out_json, "w") as f:
        json.dump(OUT, f, indent=2)

    print(f"[info] wrote {args.out_json} and {args.out_scan}")
    if best is not None:
        print(
            f"[best] R_MW_kpc={best['R_MW_kpc']}, "
            f"eta={best['eta']}, "
            f"dAICc={best['dAICc']:.3f}, "
            f"slope_vs_frac={best['slope_vs_frac']:.4f}, "
            f"used_bins={best['used_mass_bins']}"
        )
    else:
        print("[warn] no valid mass bins for AIC comparison (likely coverage-limited).")


if __name__ == "__main__":
    # run the CLI only when executed as a script, not on import
    main()
